import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os

# Import SpikingJelly from local folder (go up two levels from models/ to project root)
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'spikingjelly'))
from spikingjelly.activation_based import functional

from .spiking_cnn_layers import SpikingC2Linear, SpikingC2Conv, SpikingC2View, SpikingC2Pool, SpikingC2ConvStride1


class SpikingCNNModel(nn.Module):
    """Spiking CNN: five conv encoders, a view layer and a linear output.

    The model exposes two passes over the same layers:

    * ``forward`` pushes an input through ``enc1`` .. ``enc5``, ``view`` and
      ``output`` and returns one feature per stage.
    * ``reverse`` pushes a target through the ``reverse`` methods of the same
      layers in the opposite order and returns the matching features.

    Besides its spikes, every encoder call returns a second ("auxiliary")
    value; ``use_prelif_for_loss`` selects whether the layers are asked for
    it with ``return_prelif=True`` or ``return_readout=True``. The auxiliary
    value may be ``None`` (both passes guard against that).
    """

    # Number of target classes for every supported ``args.dataset`` value.
    _NUM_CLASSES = {
        "MNIST": 10,
        "FashionMNIST": 10,
        "CIFAR10": 10,
        "SVHN": 10,
        "MNIST_CNN": 10,
        "FashionMNIST_CNN": 10,
        "STL10_cls": 10,
        "CIFAR100": 100,
        "TinyImageNet": 200,
    }

    def __init__(self, args):
        """Build the layers from the run configuration.

        Args:
            args: Configuration object. Reads ``time_steps``, ``num_chn``,
                ``dataset`` and, optionally, ``use_label_encoding``
                (default ``False``) and ``encoding_dim`` (default ``128``).
                It is also handed unchanged to every layer constructor.

        Raises:
            ValueError: If ``args.dataset`` is not a supported dataset and
                the class count is needed (label encoding disabled).
        """
        super(SpikingCNNModel, self).__init__()

        # Number of simulation time steps used to replicate static inputs/targets.
        self.time_steps = args.time_steps

        # Resolve the class count before building any layer so that a bad
        # dataset name fails fast with a readable message instead of a late
        # AttributeError on ``self.num_classes``.
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        num_classes = self._NUM_CLASSES.get(args.dataset)
        if num_classes is not None:
            self.num_classes = num_classes
        elif not self.use_label_encoding:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r}; "
                f"expected one of {sorted(self._NUM_CLASSES)}"
            )

        # Encoder stack; channel widths follow the original CNNModel.
        self.enc1 = SpikingC2Conv(args, args.num_chn, 128, 3)
        self.enc2 = SpikingC2Conv(args, 128, 128, 3)
        self.enc3 = SpikingC2Conv(args, 128, 256, 3)
        self.enc4 = SpikingC2Conv(args, 256, 256, 3)
        self.enc5 = SpikingC2Conv(args, 256, 512, 3)
        # NOTE(review): the hard-coded (512, 1, 1) presumably assumes enc5
        # produces a 1x1 spatial map for the supported input sizes - confirm.
        self.view = SpikingC2View((512, 1, 1), 512*1*1)

        # With label encoding the output layer emits an embedding of size
        # ``encoding_dim``; otherwise one unit per class.
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        else:
            output_dim = self.num_classes

        self.output = SpikingC2Linear(args, 512*1*1, output_dim)

        # Gather the forward/backward parameter lists each layer reports.
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.output]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    @staticmethod
    def _aux_kwargs(use_prelif_for_loss):
        """Return the keyword that asks a layer for its auxiliary output."""
        if use_prelif_for_loss:
            return {'return_prelif': True}
        return {'return_readout': True}

    @staticmethod
    def _flatten_features(tensor):
        """Collapse all dimensions after the first two; pass ``None`` through."""
        if tensor is None:
            return None
        # reshape (not view) so non-contiguous layer outputs are accepted;
        # for contiguous tensors it returns the same view as before.
        return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)

    def _prepare_target(self, target):
        """Validate / expand ``target`` into the time-stepped form used by ``reverse``."""
        if self.use_label_encoding:
            # Label encoding mode: the caller must already provide a 3-D
            # tensor whose leading dimension equals ``time_steps``.
            if len(target.shape) == 3 and target.shape[0] == self.time_steps:
                return target
            raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")

        # Original mode: class indices -> one-hot floats (one_hot already
        # returns a tensor on the target's device).
        if len(target.shape) == 1:
            target = F.one_hot(target, num_classes=self.num_classes).float()

        # Replicate along a new leading time dimension: [N, classes] -> [T, N, classes]
        if len(target.shape) == 2:
            target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        return target

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the forward pass and collect one feature per stage.

        Args:
            x: Input tensor. A 4-D input ([N, C, H, W]) is replicated
                ``time_steps`` times along a new leading dimension; any other
                rank is used as given.
            detach_grad: Passed through to every layer call.
            return_spikes_for_stats: If true, additionally return the spike
                outputs of every stage.
            use_prelif_for_loss: If true the encoders are called with
                ``return_prelif=True``, otherwise with ``return_readout=True``.

        Returns:
            ``[x, aux1, aux2, aux3, aux4, aux5_flat, b]`` where ``auxK`` is the
            auxiliary output of ``encK`` (``aux5`` flattened to 3-D, or
            ``None``) and ``b`` is the output layer result computed with
            ``act=False``. When ``return_spikes_for_stats`` is true, a tuple
            ``(features, [x, s1, s2, s3, s4, s5_flat, b])`` with the spike
            outputs ``sK`` is returned instead.
        """
        if len(x.shape) == 4:  # [N, C, H, W]
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)  # [T, N, C, H, W]

        aux_kwargs = self._aux_kwargs(use_prelif_for_loss)

        # Each encoder consumes the previous stage's spikes.
        spikes, aux = [], []
        hidden = x
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            hidden, extra = enc(hidden, detach_grad, **aux_kwargs)
            spikes.append(hidden)
            aux.append(extra)

        # View and output layers; the output layer is called with act=False.
        flat_spikes = self.view(spikes[-1], detach_grad)
        aux[-1] = self._flatten_features(aux[-1])
        b = self.output(flat_spikes, detach_grad, act=False)

        features = [x, *aux, b]

        if return_spikes_for_stats:
            spike_features = [x, *spikes[:-1], self._flatten_features(spikes[-1]), b]
            return features, spike_features

        return features

    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the reverse pass from a target back towards the input.

        Args:
            target: With label encoding enabled, a 3-D tensor whose first
                dimension equals ``time_steps``. Otherwise a 1-D tensor of
                class indices (converted to one-hot) or a 2-D tensor; both
                are replicated ``time_steps`` times along a new leading
                dimension. A 3-D tensor is used as given.
            detach_grad: Passed through to every layer call.
            return_spikes_for_stats: If true, additionally return the spike
                outputs of every stage.
            use_prelif_for_loss: If true the layers are called with
                ``return_prelif=True``, otherwise with ``return_readout=True``.

        Returns:
            ``[c0, aux1, aux2, aux3, aux4, aux6, target]`` where ``c0`` is the
            result of ``enc1.reverse(..., act=False)``, ``aux1``..``aux4`` are
            the auxiliary outputs of ``enc2.reverse``..``enc5.reverse`` and
            ``aux6`` that of ``output.reverse``. When
            ``return_spikes_for_stats`` is true, a tuple
            ``(features, [c0, s1, s2, s3, s4_flat, s6, target])`` is returned.

        Raises:
            ValueError: In label encoding mode, if ``target`` is not 3-D with
                a leading dimension of ``time_steps``.
        """
        target = self._prepare_target(target)
        aux_kwargs = self._aux_kwargs(use_prelif_for_loss)

        c6_spike, c6_aux = self.output.reverse(target, detach_grad, **aux_kwargs)
        hidden = self.view.reverse(c6_spike, detach_grad)

        # Walk the encoders backwards; enc1 is handled separately below.
        spikes, aux = [], []
        for enc in (self.enc5, self.enc4, self.enc3, self.enc2):
            hidden, extra = enc.reverse(hidden, detach_grad, **aux_kwargs)
            spikes.append(hidden)
            aux.append(extra)
        c0 = self.enc1.reverse(hidden, detach_grad, act=False)

        c4_spike, c3_spike, c2_spike, c1_spike = spikes
        c4_aux, c3_aux, c2_aux, c1_aux = aux

        # NOTE(review): the enc5.reverse auxiliary output is flattened only in
        # pre-LIF mode; readout mode returns it unflattened. This asymmetry is
        # kept from the original implementation - confirm it is intended.
        if use_prelif_for_loss:
            c4_aux = self._flatten_features(c4_aux)

        features = [c0, c1_aux, c2_aux, c3_aux, c4_aux, c6_aux, target]

        if return_spikes_for_stats:
            spike_features = [
                c0, c1_spike, c2_spike, c3_spike,
                self._flatten_features(c4_spike), c6_spike, target,
            ]
            return features, spike_features

        return features


# class SpikingCNNModel_Pool(nn.Module):
#     """Spiking CNN Pool with readout - following v2 structure"""
#     def __init__(self, args):
#         super(SpikingCNNModel_Pool, self).__init__()
        
#         # Store time steps
#         self.time_steps = args.time_steps
        
#         # Following original CNNModel_Pool structure exactly
#         self.enc1 = SpikingC2ConvStride1(args, args.num_chn, 128, 3)
#         self.pool1 = SpikingC2Pool(args, 2, 2, 0)
#         self.enc2 = SpikingC2ConvStride1(args, 128, 128, 3)
#         self.pool2 = SpikingC2Pool(args, 2, 2, 0)
#         self.enc3 = SpikingC2ConvStride1(args, 128, 256, 3)
#         self.pool3 = SpikingC2Pool(args, 2, 2, 0)
#         self.enc4 = SpikingC2ConvStride1(args, 256, 256, 3)
#         self.pool4 = SpikingC2Pool(args, 2, 2, 0)
#         self.enc5 = SpikingC2ConvStride1(args, 256, 512, 3)
#         self.pool5 = SpikingC2Pool(args, 2, 2, 0)
#         self.view = SpikingC2View((512, 1, 1), 512*1*1)
        
#         # Determine number of classes
#         if args.dataset == "MNIST" or args.dataset == "FashionMNIST" or args.dataset == "CIFAR10" or args.dataset == "SVHN" or args.dataset == "STL10_cls" or args.dataset == "MNIST_CNN" or args.dataset == "FashionMNIST_CNN":
#             self.num_classes = 10
#         elif args.dataset == "CIFAR100":
#             self.num_classes = 100
        
#         # Determine output dimension based on label encoding configuration
#         self.use_label_encoding = getattr(args, 'use_label_encoding', False)
#         if self.use_label_encoding:
#             output_dim = getattr(args, 'encoding_dim', 128)
#         else:
#             output_dim = self.num_classes
        
#         self.output = SpikingC2Linear(args, 512*1*1, output_dim)
        
#         # Collect parameters (following original)
#         self.forward_params = list()
#         self.backward_params = list()
#         for layer in [self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.output]:
#             forward_params, backward_params = layer.get_parameters()
#             self.forward_params += forward_params
#             self.backward_params += backward_params
    
#     def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
#         # Convert input to time-stepped format if needed
#         if len(x.shape) == 4:  # [N, C, H, W]
#             x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)  # [T, N, C, H, W]
        
#         if use_prelif_for_loss:
#             # Forward pass - collect pre-LIF features for loss computation
#             # New order: conv+bn → maxpool → lif (prelif records maxpool output)
            
#             # Layer 1: conv+bn → maxpool → record prelif → lif
#             a1_conv_bn = self.enc1(x, detach_grad, act=False)  # Only conv+bn
#             a1_prelif = self.pool1(a1_conv_bn, detach_grad)  # maxpool result as prelif
#             a1_spike = self.enc1.forward_lif(a1_prelif)  # Apply LIF to get spikes
            
#             # Layer 2: conv+bn → maxpool → record prelif → lif  
#             a2_conv_bn = self.enc2(a1_spike, detach_grad, act=False)  # Only conv+bn
#             a2_prelif = self.pool2(a2_conv_bn, detach_grad)  # maxpool result as prelif
#             a2_spike = self.enc2.forward_lif(a2_prelif)  # Apply LIF to get spikes
            
#             # Layer 3: conv+bn → maxpool → record prelif → lif
#             a3_conv_bn = self.enc3(a2_spike, detach_grad, act=False)  # Only conv+bn
#             a3_prelif = self.pool3(a3_conv_bn, detach_grad)  # maxpool result as prelif
#             a3_spike = self.enc3.forward_lif(a3_prelif)  # Apply LIF to get spikes
            
#             # Layer 4: conv+bn → maxpool → record prelif → lif
#             a4_conv_bn = self.enc4(a3_spike, detach_grad, act=False)  # Only conv+bn
#             a4_prelif = self.pool4(a4_conv_bn, detach_grad)  # maxpool result as prelif
#             a4_spike = self.enc4.forward_lif(a4_prelif)  # Apply LIF to get spikes
            
#             # Layer 5: conv+bn → maxpool → record prelif → lif
#             a5_conv_bn = self.enc5(a4_spike, detach_grad, act=False)  # Only conv+bn
#             a5_prelif_raw = self.pool5(a5_conv_bn, detach_grad)  # maxpool result as prelif
#             a5_spike = self.enc5.forward_lif(a5_prelif_raw)  # Apply LIF to get spikes
            
#             a5_viewed = self.view(a5_spike, detach_grad)
#             # View a5_prelif to match reverse shape  
#             if a5_prelif_raw is not None:
#                 T, N = a5_prelif_raw.shape[:2]
#                 a5_prelif_viewed = a5_prelif_raw.view(T, N, -1)  # Flatten all spatial and channel dims
#             else:
#                 a5_prelif_viewed = None
#             b = self.output(a5_viewed, detach_grad, act=False)
            
#             # Return pre-LIF features for loss computation
#             prelif_features = [x, a1_prelif, a2_prelif, a3_prelif, a4_prelif, a5_prelif_viewed, b]
            
#             # Also return spike features for sparsity statistics if requested
#             if return_spikes_for_stats:
#                 # Collect spike features corresponding to prelif_features - using spikes results
#                 spike_features = [x, a1_spike, a2_spike, a3_spike, a4_spike, a5_viewed, b]
#                 return prelif_features, spike_features
            
#             return prelif_features
#         else:
#             # Original mode: Use simplified approach since readout is typically disabled
#             # Match prelif mode - conv+bn → maxpool → lif
            
#             # Layer 1: conv+bn → maxpool → lif 
#             a1_conv_bn = self.enc1(x, detach_grad, act=False)  # Only conv+bn
#             a1_pooled = self.pool1(a1_conv_bn, detach_grad)  # maxpool
#             a1_spike = self.enc1.forward_lif(a1_pooled)  # Apply LIF to get spikes
#             a1_readout = None  # Readout typically disabled
            
#             # Layer 2: conv+bn → maxpool → lif
#             a2_conv_bn = self.enc2(a1_spike, detach_grad, act=False)  # Only conv+bn
#             a2_pooled = self.pool2(a2_conv_bn, detach_grad)  # maxpool
#             a2_spike = self.enc2.forward_lif(a2_pooled)  # Apply LIF to get spikes
#             a2_readout = None  # Readout typically disabled
            
#             # Layer 3: conv+bn → maxpool → lif
#             a3_conv_bn = self.enc3(a2_spike, detach_grad, act=False)  # Only conv+bn
#             a3_pooled = self.pool3(a3_conv_bn, detach_grad)  # maxpool
#             a3_spike = self.enc3.forward_lif(a3_pooled)  # Apply LIF to get spikes
#             a3_readout = None  # Readout typically disabled
            
#             # Layer 4: conv+bn → maxpool → lif
#             a4_conv_bn = self.enc4(a3_spike, detach_grad, act=False)  # Only conv+bn
#             a4_pooled = self.pool4(a4_conv_bn, detach_grad)  # maxpool
#             a4_spike = self.enc4.forward_lif(a4_pooled)  # Apply LIF to get spikes
#             a4_readout = None  # Readout typically disabled
            
#             # Layer 5: conv+bn → maxpool → lif
#             a5_conv_bn = self.enc5(a4_spike, detach_grad, act=False)  # Only conv+bn
#             a5_pooled = self.pool5(a5_conv_bn, detach_grad)  # maxpool
#             a5_spike = self.enc5.forward_lif(a5_pooled)  # Apply LIF to get spikes
#             a5_readout_raw = None  # Readout typically disabled
#             # Keep a5_readout_raw for later view operation
#             a5_readout_pooled = a5_readout_raw
            
#             a5 = self.view(a5_spike, detach_grad)
#             # Keep a5_readout as None since readout is typically disabled
#             a5_readout = None
#             b = self.output(a5, detach_grad, act=False)
            
#             # Return readout features for loss computation
#             readout_features = [x, a1_readout, a2_readout, a3_readout, a4_readout, a5_readout, b]
            
#             # Also return spike features for sparsity statistics if requested
#             if return_spikes_for_stats:
#                 # Collect spike features directly since pool is already applied
#                 spike_features = [x, a1_spike, a2_spike, a3_spike, a4_spike, a5, b]
#                 return readout_features, spike_features
            
#             return readout_features
    
#     def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
#         # Handle target format with label encoding support
#         if self.use_label_encoding:
#             # Label encoding mode: target should already be [T, num_classes, L] or can be passed directly
#             if len(target.shape) == 3 and target.shape[0] == self.time_steps:
#                 # Already in correct [T, num_classes, L] format
#                 pass
#             else:
#                 raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")
#         else:
#             # Original mode: convert from label indices to one-hot
#             if len(target.shape) == 1:
#                 target = F.one_hot(target, num_classes=self.num_classes).float().to(target.device)
            
#             # Convert to time-stepped format
#             if len(target.shape) == 2:
#                 target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        
#         if use_prelif_for_loss:
#             # Reverse pass - collect pre-LIF features for loss computation
#             c6_spike, c6_prelif = self.output.reverse(target, detach_grad, return_prelif=True)
#             c5 = self.view.reverse(c6_spike, detach_grad)
            
#             # New order: conv+bn → upsample → lif (matching forward order)
#             # Layer 5: conv+bn → upsample → record prelif → lif
#             c4_conv_bn = self.enc5.reverse(c5, detach_grad, act=False)  # Only conv+bn
#             c4_prelif = self.pool5.reverse(c4_conv_bn, detach_grad)  # upsample result as prelif
#             c4_spike = self.enc5.backward_lif(c4_prelif)  # Apply LIF to get spikes
            
#             # Layer 4: conv+bn → upsample → record prelif → lif
#             c3_conv_bn = self.enc4.reverse(c4_spike, detach_grad, act=False)  # Only conv+bn
#             c3_prelif = self.pool4.reverse(c3_conv_bn, detach_grad)  # upsample result as prelif
#             c3_spike = self.enc4.backward_lif(c3_prelif)  # Apply LIF to get spikes
            
#             # Layer 3: conv+bn → upsample → record prelif → lif
#             c2_conv_bn = self.enc3.reverse(c3_spike, detach_grad, act=False)  # Only conv+bn
#             c2_prelif = self.pool3.reverse(c2_conv_bn, detach_grad)  # upsample result as prelif
#             c2_spike = self.enc3.backward_lif(c2_prelif)  # Apply LIF to get spikes
            
#             # Layer 2: conv+bn → upsample → record prelif → lif
#             c1_conv_bn = self.enc2.reverse(c2_spike, detach_grad, act=False)  # Only conv+bn
#             c1_prelif = self.pool2.reverse(c1_conv_bn, detach_grad)  # upsample result as prelif
#             c1_spike = self.enc2.backward_lif(c1_prelif)  # Apply LIF to get spikes
            
#             # Layer 1: conv+bn → upsample → no lif (act=False at end)
#             c0_conv_bn = self.enc1.reverse(c1_spike, detach_grad, act=False)  # Only conv+bn
#             c0 = self.pool1.reverse(c0_conv_bn, detach_grad)  # upsample to original input size
            
#             # Return pre-LIF features for loss computation
#             prelif_features = [c0, c1_prelif, c2_prelif, c3_prelif, c4_prelif, c6_prelif, target]
            
#             # Also return spike features for sparsity statistics if requested
#             if return_spikes_for_stats:
#                 # Collect spike features corresponding to prelif_features
#                 # forward: [x, a1_spike, a2_spike, a3_spike, a4_spike, a5_viewed, b]
#                 # reverse: [c0, c1_spike, c2_spike, c3_spike, c4_spike, c5, target]
#                 spike_features = [c0, c1_spike, c2_spike, c3_spike, c4_spike, c5, target]
#                 return prelif_features, spike_features
            
#             return prelif_features
#         else:
#             # Original mode: Match prelif mode - conv+bn → upsample → lif (with readout disabled)
#             c6_spike, c6_readout = self.output.reverse(target, detach_grad, return_readout=True)
#             c5 = self.view.reverse(c6_spike, detach_grad)
            
#             # New order: conv+bn → upsample → lif (readout typically disabled)
#             # Layer 5: conv+bn → upsample → lif
#             c4_conv_bn = self.enc5.reverse(c5, detach_grad, act=False)  # Only conv+bn
#             c4_upsampled = self.pool5.reverse(c4_conv_bn, detach_grad)  # upsample
#             c4_spike = self.enc5.backward_lif(c4_upsampled)  # Apply LIF to get spikes
#             c4_readout = None  # Readout typically disabled
            
#             # Layer 4: conv+bn → upsample → lif
#             c3_conv_bn = self.enc4.reverse(c4_spike, detach_grad, act=False)  # Only conv+bn
#             c3_upsampled = self.pool4.reverse(c3_conv_bn, detach_grad)  # upsample
#             c3_spike = self.enc4.backward_lif(c3_upsampled)  # Apply LIF to get spikes
#             c3_readout = None  # Readout typically disabled
            
#             # Layer 3: conv+bn → upsample → lif
#             c2_conv_bn = self.enc3.reverse(c3_spike, detach_grad, act=False)  # Only conv+bn
#             c2_upsampled = self.pool3.reverse(c2_conv_bn, detach_grad)  # upsample
#             c2_spike = self.enc3.backward_lif(c2_upsampled)  # Apply LIF to get spikes
#             c2_readout = None  # Readout typically disabled
            
#             # Layer 2: conv+bn → upsample → lif
#             c1_conv_bn = self.enc2.reverse(c2_spike, detach_grad, act=False)  # Only conv+bn
#             c1_upsampled = self.pool2.reverse(c1_conv_bn, detach_grad)  # upsample
#             c1_spike = self.enc2.backward_lif(c1_upsampled)  # Apply LIF to get spikes
#             c1_readout = None  # Readout typically disabled
            
#             # Layer 1: conv+bn → upsample → no lif (act=False at end)
#             c0_conv_bn = self.enc1.reverse(c1_spike, detach_grad, act=False)  # Only conv+bn
#             c0 = self.pool1.reverse(c0_conv_bn, detach_grad)  # upsample to original input size
            
#             # Return readout features for loss computation
#             readout_features = [c0, c1_readout, c2_readout, c3_readout, c4_readout, c6_readout, target]
            
#             # Also return spike features for sparsity statistics if requested
#             if return_spikes_for_stats:
#                 # Collect spike features corresponding to readout_features, matching new order
#                 spike_features = [c0, c1_spike, c2_spike, c3_spike, c4_spike, c5, target]
#                 return readout_features, spike_features
            
#             return readout_features

class SpikingCNNModel_Pool(nn.Module):
    """Spiking CNN Pool with readout - following v2 structure"""
    def __init__(self, args):
        super(SpikingCNNModel_Pool, self).__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # Following original CNNModel_Pool structure exactly
        self.enc1 = SpikingC2ConvStride1(args, args.num_chn, 128, 3)
        self.pool1 = SpikingC2Pool(args, 2, 2, 0)
        self.enc2 = SpikingC2ConvStride1(args, 128, 128, 3)
        self.pool2 = SpikingC2Pool(args, 2, 2, 0)
        self.enc3 = SpikingC2ConvStride1(args, 128, 256, 3)
        self.pool3 = SpikingC2Pool(args, 2, 2, 0)
        self.enc4 = SpikingC2ConvStride1(args, 256, 256, 3)
        self.pool4 = SpikingC2Pool(args, 2, 2, 0)
        self.enc5 = SpikingC2ConvStride1(args, 256, 512, 3)
        self.pool5 = SpikingC2Pool(args, 2, 2, 0)

        # Determine number of classes
        if args.dataset == "MNIST" or args.dataset == "FashionMNIST" or args.dataset == "CIFAR10" or args.dataset == "SVHN" or args.dataset == "STL10_cls" or args.dataset == "MNIST_CNN" or args.dataset == "FashionMNIST_CNN":
            self.num_classes = 10
        elif args.dataset == "CIFAR100":
            self.num_classes = 100
        elif args.dataset == "TinyImageNet":
            self.num_classes = 200

        # Determine output dimension based on label encoding configuration
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        else:
            output_dim = self.num_classes

        # Determine final spatial size based on input image size
        # 5 pooling layers with stride=2: spatial_size = input_size / 2^5
        if args.dataset == "TinyImageNet":
            final_spatial_size = 2  # 64 / 32 = 2
        else:
            final_spatial_size = 1  # 32 / 32 = 1 (CIFAR-10, CIFAR-100, etc.)

        self.view = SpikingC2View((512, final_spatial_size, final_spatial_size), 512*final_spatial_size*final_spatial_size)
        self.output = SpikingC2Linear(args, 512*final_spatial_size*final_spatial_size, output_dim)

        # Collect parameters (following original)
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.output]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False, inference_only=False):
        # Convert input to time-stepped format if needed
        if len(x.shape) == 4:  # [N, C, H, W]
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)  # [T, N, C, H, W]

        # Lightweight inference mode - only compute spike features, no readout
        if inference_only:
            # Minimal forward pass - no intermediate feature storage for loss
            a1 = self.enc1(x, detach_grad)  # Only return spikes
            a1 = self.pool1(a1, detach_grad)
            a2 = self.enc2(a1, detach_grad)
            a2 = self.pool2(a2, detach_grad)
            a3 = self.enc3(a2, detach_grad)
            a3 = self.pool3(a3, detach_grad)
            a4 = self.enc4(a3, detach_grad)
            a4 = self.pool4(a4, detach_grad)
            a5 = self.enc5(a4, detach_grad)
            a5 = self.pool5(a5, detach_grad)
            a5 = self.view(a5, detach_grad)
            b = self.output(a5, detach_grad, act=False)

            # Return only final output (no intermediate features)
            return b

        if use_prelif_for_loss:
            # Forward pass - collect pre-LIF features for loss computation
            a1_spike, a1_prelif_raw = self.enc1(x, detach_grad, return_prelif=True)
            a1_pooled = self.pool1(a1_spike, detach_grad)
            # Keep time dimension for pre-LIF
            if a1_prelif_raw is not None:
                a1_prelif = self.pool1(a1_prelif_raw, detach_grad)  # Still has time dimension
            else:
                a1_prelif = None
            
            a2_spike, a2_prelif_raw = self.enc2(a1_pooled, detach_grad, return_prelif=True)
            a2_pooled = self.pool2(a2_spike, detach_grad)
            if a2_prelif_raw is not None:
                a2_prelif = self.pool2(a2_prelif_raw, detach_grad)
            else:
                a2_prelif = None
            
            a3_spike, a3_prelif_raw = self.enc3(a2_pooled, detach_grad, return_prelif=True)
            a3_pooled = self.pool3(a3_spike, detach_grad)
            if a3_prelif_raw is not None:
                a3_prelif = self.pool3(a3_prelif_raw, detach_grad)
            else:
                a3_prelif = None
            
            a4_spike, a4_prelif_raw = self.enc4(a3_pooled, detach_grad, return_prelif=True)
            a4_pooled = self.pool4(a4_spike, detach_grad)
            if a4_prelif_raw is not None:
                a4_prelif = self.pool4(a4_prelif_raw, detach_grad)
            else:
                a4_prelif = None
            
            a5_spike, a5_prelif_raw = self.enc5(a4_pooled, detach_grad, return_prelif=True)
            a5_pooled = self.pool5(a5_spike, detach_grad)
            if a5_prelif_raw is not None:
                a5_prelif_pooled = self.pool5(a5_prelif_raw, detach_grad)
            else:
                a5_prelif_pooled = None
            
            a5_viewed = self.view(a5_pooled, detach_grad)
            # Keep a5_prelif in its original pooled form to match reverse
            a5_prelif = self.view(a5_prelif_pooled, detach_grad)
            b = self.output(a5_viewed, detach_grad, act=False)
            
            # Return pre-LIF features for loss computation
            prelif_features = [x, a1_prelif, a2_prelif, a3_prelif, a4_prelif, a5_prelif, b]
            
            # Also return spike features for sparsity statistics if requested
            if return_spikes_for_stats:
                # Collect spike features corresponding to prelif_features - directly using already computed pooled results
                spike_features = [x, a1_pooled, a2_pooled, a3_pooled, a4_pooled, a5_viewed, b]
                return prelif_features, spike_features
            
            return prelif_features
        else:
            # Original mode: Following original forward exactly
            a1_spike, a1_readout_raw = self.enc1(x, detach_grad, return_readout=True)
            a1 = self.pool1(a1_spike, detach_grad)
            # Keep time dimension for readout, let main function handle averaging
            if a1_readout_raw is not None:
                a1_readout = self.pool1(a1_readout_raw, detach_grad)  # Still has time dimension
            else:
                a1_readout = None
            
            a2_spike, a2_readout_raw = self.enc2(a1, detach_grad, return_readout=True)
            a2 = self.pool2(a2_spike, detach_grad)
            if a2_readout_raw is not None:
                a2_readout = self.pool2(a2_readout_raw, detach_grad)
            else:
                a2_readout = None
            
            a3_spike, a3_readout_raw = self.enc3(a2, detach_grad, return_readout=True)
            a3 = self.pool3(a3_spike, detach_grad)
            if a3_readout_raw is not None:
                a3_readout = self.pool3(a3_readout_raw, detach_grad)
            else:
                a3_readout = None
            
            a4_spike, a4_readout_raw = self.enc4(a3, detach_grad, return_readout=True)
            a4 = self.pool4(a4_spike, detach_grad)
            if a4_readout_raw is not None:
                a4_readout = self.pool4(a4_readout_raw, detach_grad)
            else:
                a4_readout = None
            
            a5_spike, a5_readout_raw = self.enc5(a4, detach_grad, return_readout=True)
            a5 = self.pool5(a5_spike, detach_grad)
            if a5_readout_raw is not None:
                a5_readout_pooled = self.pool5(a5_readout_raw, detach_grad)
            else:
                a5_readout_pooled = None
            
            a5 = self.view(a5, detach_grad)
            # Keep a5_readout in its original pooled form to match reverse
            a5_readout = a5_readout_pooled
            b = self.output(a5, detach_grad, act=False)
            
            # Return readout features for loss computation
            readout_features = [x, a1_readout, a2_readout, a3_readout, a4_readout, a5_readout, b]
            
            # Also return spike features for sparsity statistics if requested
            if return_spikes_for_stats:
                # Collect spike features after pooling, keep original form
                a1_spike_pooled = self.pool1(a1_spike, detach_grad)
                a2_spike_pooled = self.pool2(a2_spike, detach_grad)
                a3_spike_pooled = self.pool3(a3_spike, detach_grad)
                a4_spike_pooled = self.pool4(a4_spike, detach_grad)
                a5_spike_pooled = self.pool5(a5_spike, detach_grad)
                a5_spike_viewed = self.view(a5_spike_pooled, detach_grad)
                
                spike_features = [x, a1_spike_pooled, a2_spike_pooled, a3_spike_pooled, a4_spike_pooled, a5_spike_viewed, b]
                return readout_features, spike_features
            
            return readout_features
    
    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Propagate a target backwards through the layers' ``reverse`` methods.

        The target is first brought into time-stepped form and then passed
        through ``output.reverse``, ``view.reverse`` and the
        ``poolK.reverse`` / ``encK.reverse`` pairs from stage 5 down to stage 1.

        Args:
            target: With label encoding enabled, a 3-D tensor whose first
                dimension equals ``self.time_steps``.  Otherwise a 1-D tensor
                of class indices (one-hot encoded here) or a 2-D tensor
                (repeated ``self.time_steps`` times along a new leading axis);
                tensors of any other rank are used unchanged.
            detach_grad: Passed positionally to every ``reverse`` call.
            return_spikes_for_stats: When true, the spike outputs collected on
                the way are returned as a second list.
            use_prelif_for_loss: Chooses the auxiliary output requested from
                the layers: ``return_prelif=True`` when true, otherwise
                ``return_readout=True``.

        Returns:
            The 7-element list ``[c0, aux(enc2), aux(enc3), aux(enc4),
            aux(enc5), aux(output), target]``.  With
            ``return_spikes_for_stats`` a tuple ``(that list, spike list)``.

        Raises:
            ValueError: In label-encoding mode, if ``target`` is not 3-D with
                ``self.time_steps`` as its first dimension.
        """
        if self.use_label_encoding:
            shape_ok = len(target.shape) == 3 and target.shape[0] == self.time_steps
            if not shape_ok:
                raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")
        else:
            # Class indices -> one-hot vectors.
            if len(target.shape) == 1:
                target = F.one_hot(target, num_classes=self.num_classes).float().to(target.device)
            # Add the leading time axis by repetition.
            if len(target.shape) == 2:
                target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)

        # Both modes walk the same chain; only the requested side output differs.
        aux_request = {'return_prelif': True} if use_prelif_for_loss else {'return_readout': True}

        head_spike, head_aux = self.output.reverse(target, detach_grad, **aux_request)
        unflattened = self.view.reverse(head_spike, detach_grad)

        stages = (
            (self.pool5, self.enc5),
            (self.pool4, self.enc4),
            (self.pool3, self.enc3),
            (self.pool2, self.enc2),
        )
        stage_spikes = []
        stage_aux = []
        signal = unflattened
        for pool, enc in stages:
            signal, side = enc.reverse(pool.reverse(signal, detach_grad), detach_grad, **aux_request)
            stage_spikes.append(signal)
            stage_aux.append(side)

        # The first encoder is reversed without activation and yields one value.
        reconstruction = self.enc1.reverse(self.pool1.reverse(signal, detach_grad), detach_grad, act=False)

        # Collected from stage 5 down to 2; the returned lists run the other way.
        stage_spikes.reverse()
        stage_aux.reverse()

        loss_features = [reconstruction, *stage_aux, head_aux, target]
        if not return_spikes_for_stats:
            return loss_features

        spike_features = [reconstruction, *stage_spikes, unflattened]
        if not use_prelif_for_loss:
            # NOTE(review): readout mode additionally reports the raw
            # output.reverse spikes, giving 8 entries here versus 7 in
            # pre-LIF mode - confirm the statistics consumer expects this.
            spike_features.append(head_spike)
        spike_features.append(target)
        return loss_features, spike_features



class SpikingCNNModel_3Layer(nn.Module):
    """3-layer Spiking CNN with readout - following v2 structure.

    Three ``SpikingC2Conv`` encoders (512, 1024 and 2048 output channels, no
    pooling layers) are followed by a ``SpikingC2View`` and a
    ``SpikingC2Linear`` output layer.  ``forward`` runs the layers from input
    to output; ``reverse`` runs their ``reverse`` methods from a target back
    towards the input.  Both return feature lists with five entries.

    Note: ``use_prelif_for_loss`` is accepted by ``forward`` and ``reverse``
    but is not used by this class; layers are always queried with
    ``return_readout=True``.
    """

    # Datasets for which this model knows the class count.
    _TEN_CLASS_DATASETS = (
        "MNIST", "FashionMNIST", "CIFAR10", "SVHN",
        "STL10_cls", "MNIST_CNN", "FashionMNIST_CNN",
    )

    def __init__(self, args):
        """Build the layers and collect their forward/backward parameters.

        Args:
            args: Configuration object.  Reads ``time_steps``, ``num_chn`` and
                ``dataset``, plus the optional ``use_label_encoding``
                (default ``False``) and ``encoding_dim`` (default 128).  It is
                also handed to the layer constructors.

        Raises:
            ValueError: If label encoding is disabled and ``args.dataset`` has
                no known class count (previously an ``AttributeError`` on the
                missing ``num_classes`` attribute).
        """
        super(SpikingCNNModel_3Layer, self).__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # 3-layer ultra-wide architecture without pooling
        self.enc1 = SpikingC2Conv(args, args.num_chn, 512, 3)
        self.enc2 = SpikingC2Conv(args, 512, 1024, 3)
        self.enc3 = SpikingC2Conv(args, 1024, 2048, 3)
        self.view = SpikingC2View((2048, 4, 4), 2048*4*4)

        # num_classes is only set for known datasets, exactly as before.
        num_classes = self._lookup_num_classes(args.dataset)
        if num_classes is not None:
            self.num_classes = num_classes

        # Output width: the label-encoding dimension, or one unit per class.
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        elif num_classes is None:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r} for SpikingCNNModel_3Layer: "
                f"the number of classes is unknown and label encoding is disabled"
            )
        else:
            output_dim = num_classes

        self.output = SpikingC2Linear(args, 2048*4*4, output_dim)

        # Collect parameters of every learnable layer, in layer order.
        self.forward_params = list()
        self.backward_params = list()
        for layer in (self.enc1, self.enc2, self.enc3, self.output):
            layer_forward, layer_backward = layer.get_parameters()
            self.forward_params.extend(layer_forward)
            self.backward_params.extend(layer_backward)

    @classmethod
    def _lookup_num_classes(cls, dataset):
        """Return the class count for ``dataset``, or ``None`` if unknown."""
        if dataset in cls._TEN_CLASS_DATASETS:
            return 10
        if dataset == "CIFAR100":
            return 100
        return None

    @staticmethod
    def _flatten_per_step(tensor):
        """Collapse every dimension after the first two into a single one.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        tensors are accepted as well; for tensors ``view`` can handle the
        result is the same.
        """
        steps, batch = tensor.shape[:2]
        return tensor.reshape(steps, batch, -1)

    def _prepare_target(self, target):
        """Bring ``target`` into the time-stepped form used by ``reverse``.

        Raises:
            ValueError: In label-encoding mode, if ``target`` is not 3-D with
                ``self.time_steps`` as its first dimension.
        """
        if self.use_label_encoding:
            if len(target.shape) == 3 and target.shape[0] == self.time_steps:
                return target
            raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")

        # Class indices -> one-hot (one_hot already returns a tensor on the
        # same device as its input).
        if len(target.shape) == 1:
            target = F.one_hot(target, num_classes=self.num_classes).float()
        # Add the leading time axis by repetition.
        if len(target.shape) == 2:
            target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        return target

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the forward pass and return per-layer features.

        Args:
            x: Input tensor.  A 4-D input is repeated ``self.time_steps``
                times along a new leading axis; other ranks are used as given.
            detach_grad: Passed positionally to every layer call.
            return_spikes_for_stats: When true, also return the spike list.
            use_prelif_for_loss: Accepted but not used by this class.

        Returns:
            ``[x, f1, f2, f3, b]`` where ``f1..f3`` are the flattened readouts
            of the three encoders (or their flattened spikes when the first
            encoder returns no readout) and ``b`` is the output layer result.
            With ``return_spikes_for_stats`` a tuple of that list and
            ``[x, a1_spike, a2_spike, flattened a3_spike, b]``.
        """
        if len(x.shape) == 4:
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)

        a1_spike, a1_readout = self.enc1(x, detach_grad, return_readout=True)
        a2_spike, a2_readout = self.enc2(a1_spike, detach_grad, return_readout=True)
        a3_spike, a3_readout = self.enc3(a2_spike, detach_grad, return_readout=True)

        a3_viewed = self.view(a3_spike, detach_grad)
        b = self.output(a3_viewed, detach_grad, act=False)  # no activation on the output

        # The first encoder's readout decides for all three layers whether
        # readouts or spikes are used as loss features.
        if a1_readout is not None:
            sources = (a1_readout, a2_readout, a3_readout)
        else:
            sources = (a1_spike, a2_spike, a3_spike)
        readout_features = [x, *(self._flatten_per_step(t) for t in sources), b]

        if return_spikes_for_stats:
            spike_features = [x, a1_spike, a2_spike, self._flatten_per_step(a3_spike), b]
            return readout_features, spike_features

        return readout_features

    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the reverse pass from ``target`` back towards the input.

        Args:
            target: See ``_prepare_target`` - class indices, a 2-D tensor, or
                (in label-encoding mode) a 3-D time-stepped tensor.
            detach_grad: Passed positionally to every ``reverse`` call.
            return_spikes_for_stats: When true, also return the spike list.
            use_prelif_for_loss: Accepted but not used by this class.

        Returns:
            ``[c0, f1, f2, f3, target]`` where ``f1``/``f2`` are the flattened
            readouts of ``enc2.reverse``/``enc3.reverse`` and ``f3`` is the
            readout of ``output.reverse`` (or the flattened spikes of those
            three calls when ``enc2.reverse`` returns no readout).  With
            ``return_spikes_for_stats`` a tuple of that list and
            ``[c0, c1_spike, flattened c2_spike, c3_spike, target]``.

        Raises:
            ValueError: In label-encoding mode, for a malformed ``target``.
        """
        target = self._prepare_target(target)

        c3_spike, c3_readout = self.output.reverse(target, detach_grad, return_readout=True)
        c2 = self.view.reverse(c3_spike, detach_grad)

        c2_spike, c2_readout = self.enc3.reverse(c2, detach_grad, return_readout=True)
        c1_spike, c1_readout = self.enc2.reverse(c2_spike, detach_grad, return_readout=True)
        c0 = self.enc1.reverse(c1_spike, detach_grad, act=False)

        if c1_readout is not None:
            c1_feature = self._flatten_per_step(c1_readout)
            c2_feature = self._flatten_per_step(c2_readout)
            # The output.reverse readout is used as returned (not flattened).
            c3_feature = c3_readout
        else:
            c1_feature = self._flatten_per_step(c1_spike)
            c2_feature = self._flatten_per_step(c2_spike)
            c3_feature = self._flatten_per_step(c3_spike)

        readout_features = [c0, c1_feature, c2_feature, c3_feature, target]

        if return_spikes_for_stats:
            spike_features = [c0, c1_spike, self._flatten_per_step(c2_spike), c3_spike, target]
            return readout_features, spike_features

        return readout_features


class SpikingCNNModel_WideShallow(nn.Module):
    """Three-layer CNN with 8x width - ultra-wide shallow convolution design.

    Three ``SpikingC2ConvStride1`` + ``SpikingC2Pool`` stages (512, 1024 and
    2048 channels) are followed by a ``SpikingC2View``, an 8192-unit
    ``SpikingC2Linear`` (``fc``) and a ``SpikingC2Linear`` output layer.
    ``forward`` and ``reverse`` each return a six-entry feature list built
    either from pre-LIF values (``use_prelif_for_loss=True``) or from
    readouts.
    """

    # Datasets for which this model knows the class count.
    _TEN_CLASS_DATASETS = (
        "MNIST", "FashionMNIST", "CIFAR10", "SVHN",
        "STL10_cls", "MNIST_CNN", "FashionMNIST_CNN",
    )

    def __init__(self, args):
        """Build the layers and collect their forward/backward parameters.

        Args:
            args: Configuration object.  Reads ``time_steps``, ``num_chn`` and
                ``dataset``, plus the optional ``use_label_encoding``
                (default ``False``) and ``encoding_dim`` (default 128).  It is
                also handed to the layer constructors.

        Raises:
            ValueError: If label encoding is disabled and ``args.dataset`` has
                no known class count (previously an ``AttributeError`` on the
                missing ``num_classes`` attribute).
        """
        super(SpikingCNNModel_WideShallow, self).__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # Stage 1: 3x3 conv to 512 channels, then 2x2 pooling
        # (original note: 32x32 -> 32x32 -> 16x16).
        self.enc1 = SpikingC2ConvStride1(args, args.num_chn, 512, 3)
        self.pool1 = SpikingC2Pool(args, 2, 2, 0)

        # Stage 2: 3x3 conv to 1024 channels, then 2x2 pooling
        # (original note: 16x16 -> 16x16 -> 8x8).
        self.enc2 = SpikingC2ConvStride1(args, 512, 1024, 3)
        self.pool2 = SpikingC2Pool(args, 2, 2, 0)

        # Stage 3: 3x3 conv to 2048 channels, then 2x2 pooling
        # (original note: 8x8 -> 8x8 -> 4x4).
        self.enc3 = SpikingC2ConvStride1(args, 1024, 2048, 3)
        self.pool3 = SpikingC2Pool(args, 2, 2, 0)

        # Flatten (2048, 4, 4) to 32768 features.
        self.view = SpikingC2View((2048, 4, 4), 2048*4*4)

        # Fully connected layer with 8192 units.
        self.fc = SpikingC2Linear(args, 2048*4*4, 8192)

        # num_classes is only set for known datasets, exactly as before.
        num_classes = self._lookup_num_classes(args.dataset)
        if num_classes is not None:
            self.num_classes = num_classes

        # Output width: the label-encoding dimension, or one unit per class.
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        elif num_classes is None:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r} for SpikingCNNModel_WideShallow: "
                f"the number of classes is unknown and label encoding is disabled"
            )
        else:
            output_dim = num_classes

        self.output = SpikingC2Linear(args, 8192, output_dim)

        # Collect parameters of every learnable layer, in layer order.
        self.forward_params = list()
        self.backward_params = list()
        for layer in (self.enc1, self.enc2, self.enc3, self.fc, self.output):
            layer_forward, layer_backward = layer.get_parameters()
            self.forward_params.extend(layer_forward)
            self.backward_params.extend(layer_backward)

    @classmethod
    def _lookup_num_classes(cls, dataset):
        """Return the class count for ``dataset``, or ``None`` if unknown."""
        if dataset in cls._TEN_CLASS_DATASETS:
            return 10
        if dataset == "CIFAR100":
            return 100
        return None

    @staticmethod
    def _aux_request(use_prelif_for_loss):
        """Return the keyword that asks a layer for its side output."""
        if use_prelif_for_loss:
            return {'return_prelif': True}
        return {'return_readout': True}

    @staticmethod
    def _flatten_per_step(tensor):
        """Collapse every dimension after the first two into a single one.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        tensors are accepted as well; for tensors ``view`` can handle the
        result is the same.
        """
        steps, batch = tensor.shape[:2]
        return tensor.reshape(steps, batch, -1)

    @staticmethod
    def _conv_pool(enc, pool, x, detach_grad, aux_request):
        """Run one conv stage and pool both its spikes and its side output.

        Returns ``(pooled_spikes, pooled_side)``; ``pooled_side`` is ``None``
        when the encoder returned no side output.
        """
        spike, side = enc(x, detach_grad, **aux_request)
        pooled_spike = pool(spike, detach_grad)
        pooled_side = pool(side, detach_grad) if side is not None else None
        return pooled_spike, pooled_side

    def _prepare_target(self, target):
        """Bring ``target`` into the time-stepped form used by ``reverse``.

        Raises:
            ValueError: In label-encoding mode, if ``target`` is not 3-D with
                ``self.time_steps`` as its first dimension.
        """
        if self.use_label_encoding:
            if len(target.shape) == 3 and target.shape[0] == self.time_steps:
                return target
            raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")

        # Class indices -> one-hot (one_hot already returns a tensor on the
        # same device as its input).
        if len(target.shape) == 1:
            target = F.one_hot(target, num_classes=self.num_classes).float()
        # Add the leading time axis by repetition.
        if len(target.shape) == 2:
            target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        return target

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the forward pass and return per-layer features.

        Args:
            x: Input tensor.  A 4-D input is repeated ``self.time_steps``
                times along a new leading axis; other ranks are used as given.
            detach_grad: Passed positionally to every layer call.
            return_spikes_for_stats: When true, also return the spike list.
            use_prelif_for_loss: Request pre-LIF values
                (``return_prelif=True``) instead of readouts
                (``return_readout=True``) from the layers.

        Returns:
            Six features ``[x, side1, side2, third, fc_side, b]`` where
            ``side1``/``side2`` are the pooled side outputs of stages 1 and 2,
            ``fc_side`` is the side output of ``fc`` and ``b`` the output
            layer result.  ``third`` is the flattened stage-3 spikes (the
            ``view`` output) in pre-LIF mode and the pooled stage-3 readout
            otherwise.  With ``return_spikes_for_stats`` a tuple of that list
            and ``[x, pooled1, pooled2, flattened pooled3, fc_spike, b]``.
        """
        if len(x.shape) == 4:
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)

        aux_request = self._aux_request(use_prelif_for_loss)

        a1_pooled, a1_side = self._conv_pool(self.enc1, self.pool1, x, detach_grad, aux_request)
        a2_pooled, a2_side = self._conv_pool(self.enc2, self.pool2, a1_pooled, detach_grad, aux_request)
        # The stage-3 side output is pooled in both modes even though pre-LIF
        # mode does not return it: the original call sequence is kept in case
        # the pooling layer keeps state (not visible from this file).
        a3_pooled, a3_side = self._conv_pool(self.enc3, self.pool3, a2_pooled, detach_grad, aux_request)

        flattened = self.view(a3_pooled, detach_grad)
        fc_spike, fc_side = self.fc(flattened, detach_grad, **aux_request)

        # Output layer (no LIF activation)
        b = self.output(fc_spike, detach_grad, act=False)

        third = flattened if use_prelif_for_loss else a3_side
        features = [x, a1_side, a2_side, third, fc_side, b]

        if return_spikes_for_stats:
            spike_features = [x, a1_pooled, a2_pooled, self._flatten_per_step(a3_pooled), fc_spike, b]
            return features, spike_features

        return features

    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the reverse pass from ``target`` back towards the input.

        The chain is ``output.reverse`` -> ``fc.reverse`` -> ``view.reverse``
        -> ``pool3.reverse`` -> ``enc3.reverse`` -> ``pool2.reverse`` ->
        ``enc2.reverse`` -> ``pool1.reverse`` -> ``enc1.reverse``.

        Args:
            target: See ``_prepare_target`` - class indices, a 2-D tensor, or
                (in label-encoding mode) a 3-D time-stepped tensor.
            detach_grad: Passed positionally to every ``reverse`` call.
            return_spikes_for_stats: When true, also return the spike list.
            use_prelif_for_loss: Request pre-LIF values instead of readouts.

        Returns:
            Six features ``[c0, side(enc2), side(enc3), fourth, side(output),
            target]``.  ``fourth`` is the spike output of ``fc.reverse`` in
            pre-LIF mode and its readout otherwise.  With
            ``return_spikes_for_stats`` a tuple of that list and
            ``[c0, spike(enc2), spike(enc3), spike(fc), spike(output),
            target]``; in readout mode the ``enc3`` spikes are flattened.

        Raises:
            ValueError: In label-encoding mode, for a malformed ``target``.
        """
        target = self._prepare_target(target)
        aux_request = self._aux_request(use_prelif_for_loss)

        out_spike, out_side = self.output.reverse(target, detach_grad, **aux_request)
        fc_spike, fc_side = self.fc.reverse(out_spike, detach_grad, **aux_request)

        # view.reverse, then undo stage 3.
        signal = self.view.reverse(fc_spike, detach_grad)
        signal = self.pool3.reverse(signal, detach_grad)
        enc3_spike, enc3_side = self.enc3.reverse(signal, detach_grad, **aux_request)

        # Undo stage 2.
        signal = self.pool2.reverse(enc3_spike, detach_grad)
        enc2_spike, enc2_side = self.enc2.reverse(signal, detach_grad, **aux_request)

        # Undo stage 1; the first encoder is reversed without activation.
        signal = self.pool1.reverse(enc2_spike, detach_grad)
        c0 = self.enc1.reverse(signal, detach_grad, act=False)

        # Pre-LIF mode reports the fc.reverse spikes in the fourth slot,
        # readout mode the fc.reverse readout.
        fourth = fc_spike if use_prelif_for_loss else fc_side
        features = [c0, enc2_side, enc3_side, fourth, out_side, target]

        if return_spikes_for_stats:
            # Only readout mode flattens the enc3 spikes (kept as in the
            # original implementation).
            enc3_stat = enc3_spike if use_prelif_for_loss else self._flatten_per_step(enc3_spike)
            spike_features = [c0, enc2_spike, enc3_stat, fc_spike, out_spike, target]
            return features, spike_features

        return features


class SpikingCNNModel_VGG16(nn.Module):
    """Spiking VGG-16 with readout - following standard VGG-16 structure"""
    def __init__(self, args):
        """Build the VGG-16 style layer stack and collect its parameters.

        Args:
            args: Configuration object.  Reads ``time_steps``, ``num_chn`` and
                ``dataset``, plus the optional ``use_label_encoding``
                (default ``False``) and ``encoding_dim`` (default 128).  It is
                also handed to the layer constructors.

        Raises:
            ValueError: If label encoding is disabled and ``args.dataset`` has
                no known class count (previously an ``AttributeError`` on the
                missing ``num_classes`` attribute).
        """
        super(SpikingCNNModel_VGG16, self).__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # VGG-16 structure:
        # Block 1: 2x Conv3x3(64) + pool
        self.conv1_1 = SpikingC2ConvStride1(args, args.num_chn, 64, 3)
        self.conv1_2 = SpikingC2ConvStride1(args, 64, 64, 3)
        self.pool1 = SpikingC2Pool(args, 2, 2, 0)

        # Block 2: 2x Conv3x3(128) + pool
        self.conv2_1 = SpikingC2ConvStride1(args, 64, 128, 3)
        self.conv2_2 = SpikingC2ConvStride1(args, 128, 128, 3)
        self.pool2 = SpikingC2Pool(args, 2, 2, 0)

        # Block 3: 3x Conv3x3(256) + pool
        self.conv3_1 = SpikingC2ConvStride1(args, 128, 256, 3)
        self.conv3_2 = SpikingC2ConvStride1(args, 256, 256, 3)
        self.conv3_3 = SpikingC2ConvStride1(args, 256, 256, 3)
        self.pool3 = SpikingC2Pool(args, 2, 2, 0)

        # Block 4: 3x Conv3x3(512) + pool
        self.conv4_1 = SpikingC2ConvStride1(args, 256, 512, 3)
        self.conv4_2 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.conv4_3 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.pool4 = SpikingC2Pool(args, 2, 2, 0)

        # Block 5: 3x Conv3x3(512) + pool
        self.conv5_1 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.conv5_2 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.conv5_3 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.pool5 = SpikingC2Pool(args, 2, 2, 0)

        # View and fully connected layers
        self.view = SpikingC2View((512, 1, 1), 512*1*1)
        self.fc1 = SpikingC2Linear(args, 512*1*1, 4096)

        # Determine number of classes; the attribute is only set for known
        # datasets, exactly as before.
        ten_class_datasets = (
            "MNIST", "FashionMNIST", "CIFAR10", "SVHN",
            "STL10_cls", "MNIST_CNN", "FashionMNIST_CNN",
        )
        if args.dataset in ten_class_datasets:
            num_classes = 10
        elif args.dataset == "CIFAR100":
            num_classes = 100
        else:
            num_classes = None
        if num_classes is not None:
            self.num_classes = num_classes

        # Output width: the label-encoding dimension, or one unit per class.
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        elif num_classes is None:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r} for SpikingCNNModel_VGG16: "
                f"the number of classes is unknown and label encoding is disabled"
            )
        else:
            output_dim = num_classes

        self.fc2 = SpikingC2Linear(args, 4096, output_dim)

        # Collect parameters of all 15 learnable layers (13 conv + 2 linear),
        # in network order.
        self.forward_params = list()
        self.backward_params = list()

        learnable_layers = (
            self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2,
            self.conv3_1, self.conv3_2, self.conv3_3,
            self.conv4_1, self.conv4_2, self.conv4_3,
            self.conv5_1, self.conv5_2, self.conv5_3,
            self.fc1, self.fc2,
        )

        for layer in learnable_layers:
            layer_forward, layer_backward = layer.get_parameters()
            self.forward_params.extend(layer_forward)
            self.backward_params.extend(layer_backward)
    
    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        # Convert input to time-stepped format if needed
        if len(x.shape) == 4:  # [N, C, H, W]
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)  # [T, N, C, H, W]
        
        if use_prelif_for_loss:
            # Forward pass - collect pre-LIF features for EVERY learnable layer (15 layers total)
            # IMPORTANT: Apply pooling to prelif features that need it to match reverse spatial dimensions
            
            # Layer 1: conv1_1
            a1_spike, a1_prelif = self.conv1_1(x, detach_grad, return_prelif=True)
            
            # Layer 2: conv1_2  
            a2_spike, a2_prelif_raw = self.conv1_2(a1_spike, detach_grad, return_prelif=True)
            a2_pooled = self.pool1(a2_spike, detach_grad)
            # Apply pool to prelif to match reverse dimensions
            a2_prelif = self.pool1(a2_prelif_raw, detach_grad) if a2_prelif_raw is not None else None
            
            # Layer 3: conv2_1
            a3_spike, a3_prelif = self.conv2_1(a2_pooled, detach_grad, return_prelif=True)
            
            # Layer 4: conv2_2
            a4_spike, a4_prelif_raw = self.conv2_2(a3_spike, detach_grad, return_prelif=True)
            a4_pooled = self.pool2(a4_spike, detach_grad)
            # Apply pool to prelif to match reverse dimensions
            a4_prelif = self.pool2(a4_prelif_raw, detach_grad) if a4_prelif_raw is not None else None
            
            # Layer 5: conv3_1
            a5_spike, a5_prelif = self.conv3_1(a4_pooled, detach_grad, return_prelif=True)
            
            # Layer 6: conv3_2
            a6_spike, a6_prelif = self.conv3_2(a5_spike, detach_grad, return_prelif=True)
            
            # Layer 7: conv3_3
            a7_spike, a7_prelif_raw = self.conv3_3(a6_spike, detach_grad, return_prelif=True)
            a7_pooled = self.pool3(a7_spike, detach_grad)
            # Apply pool to prelif to match reverse dimensions
            a7_prelif = self.pool3(a7_prelif_raw, detach_grad) if a7_prelif_raw is not None else None
            
            # Layer 8: conv4_1
            a8_spike, a8_prelif = self.conv4_1(a7_pooled, detach_grad, return_prelif=True)
            
            # Layer 9: conv4_2
            a9_spike, a9_prelif = self.conv4_2(a8_spike, detach_grad, return_prelif=True)
            
            # Layer 10: conv4_3
            a10_spike, a10_prelif_raw = self.conv4_3(a9_spike, detach_grad, return_prelif=True)
            a10_pooled = self.pool4(a10_spike, detach_grad)
            # Apply pool to prelif to match reverse dimensions
            a10_prelif = self.pool4(a10_prelif_raw, detach_grad) if a10_prelif_raw is not None else None
            
            # Layer 11: conv5_1
            a11_spike, a11_prelif = self.conv5_1(a10_pooled, detach_grad, return_prelif=True)
            
            # Layer 12: conv5_2
            a12_spike, a12_prelif = self.conv5_2(a11_spike, detach_grad, return_prelif=True)
            
            # Layer 13: conv5_3
            a13_spike, a13_prelif_raw = self.conv5_3(a12_spike, detach_grad, return_prelif=True)
            a13_pooled = self.pool5(a13_spike, detach_grad)
            # Apply pool to prelif to match reverse spatial dimensions
            if a13_prelif_raw is not None:
                a13_prelif_pooled = self.pool5(a13_prelif_raw, detach_grad)
                # Flatten to match fc1.reverse's prelif dimension (4096)
                # But wait, fc1.reverse's prelif is 4096-dim, conv5_3's is 512-dim
                # This is the fundamental mismatch caused by the view layer
            else:
                a13_prelif_pooled = None
            
            # View layer (reshape to flatten)
            a13_viewed = self.view(a13_pooled, detach_grad)
            # Also apply view to conv5_3's prelif to get 512-dim flattened
            if a13_prelif_pooled is not None:
                T, N = a13_prelif_pooled.shape[:2]
                a13_prelif_viewed = a13_prelif_pooled.view(T, N, -1)  # [T, N, 512]
            else:
                a13_prelif_viewed = None
            
            # Layer 14: fc1
            a14_spike, a14_prelif = self.fc1(a13_viewed, detach_grad, return_prelif=True)
            
            # Layer 15: fc2 (output)
            a15 = self.fc2(a14_spike, detach_grad, act=False)
            
            # The issue: VGG-16 has a non-learnable view layer that creates a dimension mismatch
            # Forward: conv5_3 -> pool -> view(512) -> fc1(512->4096) -> fc2(4096->10)
            # Reverse: fc2.reverse(10->4096) -> fc1.reverse(4096->512) -> view.reverse -> pool.reverse -> conv5_3.reverse
            #
            # The correct correspondence for gradient flow:
            # Position 13: fc1's input (viewed conv5_3 output, 512-dim) to match fc1.reverse's output (512-dim)
            # Position 14: fc1's prelif (4096-dim) to match fc2.reverse's output (4096-dim)
            # Position 15: fc2's output (10-dim) to match target (10-dim)
            #
            # This way:
            # - conv5_3 gets gradient from position 13 loss term
            # - fc1.backward gets gradient from position 14 loss term  
            # - fc2.backward gets gradient from position 15 loss term
            
            prelif_features = [x, a1_prelif, a2_prelif, a3_prelif, a4_prelif, a5_prelif, a6_prelif, a7_prelif, 
                             a8_prelif, a9_prelif, a10_prelif, a11_prelif, a12_prelif, a13_prelif_viewed, a14_prelif, a15]
            
            # Also return spike features for sparsity statistics if requested
            if return_spikes_for_stats:
                spike_features = [x, a1_spike, a2_spike, a3_spike, a4_spike, a5_spike, a6_spike, a7_spike,
                                a8_spike, a9_spike, a10_spike, a11_spike, a12_spike, a13_spike, a14_spike, a15]
                return prelif_features, spike_features
            
            return prelif_features
        else:
            # Original mode: collect readout features for EVERY learnable layer (15 layers total)
            # Layer 1: conv1_1
            a1_spike, a1_readout = self.conv1_1(x, detach_grad, return_readout=True)
            
            # Layer 2: conv1_2  
            a2_spike, a2_readout = self.conv1_2(a1_spike, detach_grad, return_readout=True)
            a2_pooled = self.pool1(a2_spike, detach_grad)
            
            # Layer 3: conv2_1
            a3_spike, a3_readout = self.conv2_1(a2_pooled, detach_grad, return_readout=True)
            
            # Layer 4: conv2_2
            a4_spike, a4_readout = self.conv2_2(a3_spike, detach_grad, return_readout=True)
            a4_pooled = self.pool2(a4_spike, detach_grad)
            
            # Layer 5: conv3_1
            a5_spike, a5_readout = self.conv3_1(a4_pooled, detach_grad, return_readout=True)
            
            # Layer 6: conv3_2
            a6_spike, a6_readout = self.conv3_2(a5_spike, detach_grad, return_readout=True)
            
            # Layer 7: conv3_3
            a7_spike, a7_readout = self.conv3_3(a6_spike, detach_grad, return_readout=True)
            a7_pooled = self.pool3(a7_spike, detach_grad)
            
            # Layer 8: conv4_1
            a8_spike, a8_readout = self.conv4_1(a7_pooled, detach_grad, return_readout=True)
            
            # Layer 9: conv4_2
            a9_spike, a9_readout = self.conv4_2(a8_spike, detach_grad, return_readout=True)
            
            # Layer 10: conv4_3
            a10_spike, a10_readout = self.conv4_3(a9_spike, detach_grad, return_readout=True)
            a10_pooled = self.pool4(a10_spike, detach_grad)
            
            # Layer 11: conv5_1
            a11_spike, a11_readout = self.conv5_1(a10_pooled, detach_grad, return_readout=True)
            
            # Layer 12: conv5_2
            a12_spike, a12_readout = self.conv5_2(a11_spike, detach_grad, return_readout=True)
            
            # Layer 13: conv5_3
            a13_spike, a13_readout = self.conv5_3(a12_spike, detach_grad, return_readout=True)
            a13_pooled = self.pool5(a13_spike, detach_grad)
            
            # Layer 14: fc1
            a13_viewed = self.view(a13_pooled, detach_grad)
            a14_spike, a14_readout = self.fc1(a13_viewed, detach_grad, return_readout=True)
            
            # Layer 15: fc2 (output)
            a15 = self.fc2(a14_spike, detach_grad, act=False)
            
            # Return ALL 15 learnable layers + output (16 total)
            readout_features = [x, a1_readout, a2_readout, a3_readout, a4_readout, a5_readout, a6_readout, a7_readout,
                              a8_readout, a9_readout, a10_readout, a11_readout, a12_readout, a13_readout, a14_readout, a15]
            
            # Also return spike features for sparsity statistics if requested
            if return_spikes_for_stats:
                spike_features = [x, a1_spike, a2_spike, a3_spike, a4_spike, a5_spike, a6_spike, a7_spike,
                                a8_spike, a9_spike, a10_spike, a11_spike, a12_spike, a13_spike, a14_spike, a15]
                return readout_features, spike_features
            
            return readout_features
    
    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Propagate ``target`` backwards through every layer's ``reverse`` path.

        Execution order: fc2 -> fc1 -> view -> pool5 -> conv5_3 -> conv5_2 ->
        conv5_1 -> pool4 -> conv4_3 -> conv4_2 -> conv4_1 -> pool3 -> conv3_3 ->
        conv3_2 -> conv3_1 -> pool2 -> conv2_2 -> conv2_1 -> pool1 -> conv1_2 ->
        conv1_1 (the last call uses ``act=False``).

        Args:
            target: With label encoding enabled, a 3-D tensor whose leading
                dimension equals ``self.time_steps``. Otherwise a 1-D tensor is
                one-hot encoded with ``self.num_classes`` classes and a 2-D
                tensor is repeated ``self.time_steps`` times along a new
                leading dimension; tensors of any other rank pass through
                unchanged.
            detach_grad: Passed unchanged to every layer's ``reverse``.
            return_spikes_for_stats: If true, additionally return the list of
                per-layer reverse outputs (used for sparsity statistics).
            use_prelif_for_loss: If true, layers are queried with
                ``return_prelif=True``; otherwise with ``return_readout=True``.

        Returns:
            ``features`` (16 entries), or ``(features, spike_features)`` when
            ``return_spikes_for_stats`` is set.

            * ``features[0]`` is the output of ``conv1_1.reverse`` and
              ``features[15]`` is the prepared target.
            * Readout mode: ``features[1:15]`` are the readouts returned by
              conv1_2 ... conv5_3, fc1, fc2 (in that order).
            * Pre-LIF mode: ``features[1:13]`` are the pre-LIF values returned
              by conv1_2 ... conv5_3, while indices 13 and 14 hold the *spike*
              outputs of ``fc1.reverse`` and ``fc2.reverse``.
            * ``spike_features`` is ``[conv1_1 output, conv1_2 spike, ...,
              conv5_3 spike, fc1 spike, fc2 spike, target]`` in both modes.

        Raises:
            ValueError: In label-encoding mode when ``target`` is not 3-D with
                ``self.time_steps`` as its leading dimension.
        """
        if self.use_label_encoding:
            # Label-encoded targets must arrive already time-stepped.
            if len(target.shape) != 3 or target.shape[0] != self.time_steps:
                raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")
        else:
            if len(target.shape) == 1:
                target = F.one_hot(target, num_classes=self.num_classes).float().to(target.device)
            if len(target.shape) == 2:
                target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)

        # Both modes run the same call sequence; they differ only in which
        # auxiliary value each learnable layer is asked to return.
        if use_prelif_for_loss:
            aux_request = {"return_prelif": True}
        else:
            aux_request = {"return_readout": True}

        # (non-learnable steps to undo first, learnable layer), in execution order.
        chain = (
            ((), self.fc2),
            ((), self.fc1),
            ((self.view, self.pool5), self.conv5_3),
            ((), self.conv5_2),
            ((), self.conv5_1),
            ((self.pool4,), self.conv4_3),
            ((), self.conv4_2),
            ((), self.conv4_1),
            ((self.pool3,), self.conv3_3),
            ((), self.conv3_2),
            ((), self.conv3_1),
            ((self.pool2,), self.conv2_2),
            ((), self.conv2_1),
            ((self.pool1,), self.conv1_2),
        )

        spikes = []
        aux_values = []
        signal = target
        for undo_steps, layer in chain:
            for step in undo_steps:
                signal = step.reverse(signal, detach_grad)
            signal, aux = layer.reverse(signal, detach_grad, **aux_request)
            spikes.append(signal)
            aux_values.append(aux)

        # Final step: conv1_1 produces the input reconstruction without activation.
        input_reconstructed = self.conv1_1.reverse(signal, detach_grad, act=False)

        # Re-order from execution order (fc2 first) to layer order (conv1_2 first).
        spikes.reverse()
        aux_values.reverse()

        if use_prelif_for_loss:
            # The two FC positions use the reverse spikes rather than the
            # pre-LIF values: fc1's spike sits at index 13 and fc2's at index 14.
            features = [input_reconstructed, *aux_values[:-2], *spikes[-2:], target]
        else:
            features = [input_reconstructed, *aux_values, target]

        if return_spikes_for_stats:
            # Every layer's reverse output appears exactly once (including
            # fc2's spikes), identically in both modes.
            spike_features = [input_reconstructed, *spikes, target]
            return features, spike_features

        return features


class SpikingCNNModel_VGG8(nn.Module):
    """Spiking VGG-8 with readout - a shallower version of VGG-16.

    Layout: four blocks of two 3x3 convolutions (64, 128, 256 and 512 output
    channels), each block followed by a pooling layer, then a flatten ("view")
    step and two fully connected layers (512*2*2 -> 1024 -> output).
    That is 8 conv + 2 FC = 10 learnable layers.
    """

    def __init__(self, args):
        """Build the layers and collect their forward/backward parameters.

        Args:
            args: Configuration object. Read here: ``time_steps``, ``num_chn``,
                ``dataset`` and, optionally, ``use_label_encoding`` (default
                ``False``) and ``encoding_dim`` (default 128). It is also
                handed to every learnable layer and every pooling layer.

        Raises:
            ValueError: If label encoding is disabled and ``args.dataset`` has
                no known class count, so the output width cannot be derived.
        """
        super(SpikingCNNModel_VGG8, self).__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # Block 1: 2x Conv3x3(64) + pool
        self.conv1_1 = SpikingC2ConvStride1(args, args.num_chn, 64, 3)
        self.conv1_2 = SpikingC2ConvStride1(args, 64, 64, 3)
        self.pool1 = SpikingC2Pool(args, 2, 2, 0)

        # Block 2: 2x Conv3x3(128) + pool
        self.conv2_1 = SpikingC2ConvStride1(args, 64, 128, 3)
        self.conv2_2 = SpikingC2ConvStride1(args, 128, 128, 3)
        self.pool2 = SpikingC2Pool(args, 2, 2, 0)

        # Block 3: 2x Conv3x3(256) + pool
        self.conv3_1 = SpikingC2ConvStride1(args, 128, 256, 3)
        self.conv3_2 = SpikingC2ConvStride1(args, 256, 256, 3)
        self.pool3 = SpikingC2Pool(args, 2, 2, 0)

        # Block 4: 2x Conv3x3(512) + pool
        self.conv4_1 = SpikingC2ConvStride1(args, 256, 512, 3)
        self.conv4_2 = SpikingC2ConvStride1(args, 512, 512, 3)
        self.pool4 = SpikingC2Pool(args, 2, 2, 0)

        # View and fully connected layers
        self.view = SpikingC2View((512, 2, 2), 512*2*2)
        self.fc1 = SpikingC2Linear(args, 512*2*2, 1024)

        # Determine number of classes (left unset for unrecognised datasets)
        ten_class_datasets = (
            "MNIST", "FashionMNIST", "CIFAR10", "SVHN",
            "STL10_cls", "MNIST_CNN", "FashionMNIST_CNN",
        )
        if args.dataset in ten_class_datasets:
            self.num_classes = 10
        elif args.dataset == "CIFAR100":
            self.num_classes = 100

        # Determine output dimension based on label encoding configuration
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        elif hasattr(self, 'num_classes'):
            output_dim = self.num_classes
        else:
            # Fail with an explicit message instead of an AttributeError on
            # the missing ``num_classes`` attribute.
            raise ValueError(
                f"SpikingCNNModel_VGG8 does not know the number of classes for dataset {args.dataset!r}"
            )

        self.fc2 = SpikingC2Linear(args, 1024, output_dim)

        # Collect parameters of all learnable layers, in layer order
        self.forward_params = []
        self.backward_params = []

        learnable_layers = [
            self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2,
            self.conv3_1, self.conv3_2, self.conv4_1, self.conv4_2,
            self.fc1, self.fc2
        ]

        for layer in learnable_layers:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def _prepare_target(self, target):
        """Validate or convert ``target`` into the time-stepped form used by ``reverse``."""
        if self.use_label_encoding:
            # Label-encoded targets must arrive already time-stepped.
            if len(target.shape) != 3 or target.shape[0] != self.time_steps:
                raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")
            return target

        # Original mode: class indices -> one-hot -> repeated over time steps
        if len(target.shape) == 1:
            target = F.one_hot(target, num_classes=self.num_classes).float().to(target.device)
        if len(target.shape) == 2:
            target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        return target

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the forward pass and collect one feature per learnable layer.

        Args:
            x: Input tensor. A 4-D input is repeated ``self.time_steps`` times
                along a new leading dimension; other ranks are used as-is.
            detach_grad: Passed unchanged to every layer call.
            return_spikes_for_stats: If true, additionally return the list of
                per-layer outputs (used for sparsity statistics).
            use_prelif_for_loss: If true, layers are called with
                ``return_prelif=True`` and the pre-LIF values of the last conv
                in each block are passed through that block's pooling layer
                (the last one is additionally flattened to ``[T, N, -1]``).
                Otherwise layers are called with ``return_readout=True`` and
                the readouts are collected untouched.

        Returns:
            ``features`` (11 entries: ``x``, eight conv features, the fc1
            feature, the fc2 output), or ``(features, spike_features)`` when
            ``return_spikes_for_stats`` is set. ``spike_features`` holds ``x``,
            the un-pooled spike outputs of the eight conv layers, the fc1
            spikes and the fc2 output.
        """
        # Convert input to time-stepped format if needed
        if len(x.shape) == 4:  # [N, C, H, W]
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)  # [T, N, C, H, W]

        if use_prelif_for_loss:
            aux_request = {"return_prelif": True}
        else:
            aux_request = {"return_readout": True}

        # (conv layer, pooling layer applied to its output or None)
        conv_stack = (
            (self.conv1_1, None),
            (self.conv1_2, self.pool1),
            (self.conv2_1, None),
            (self.conv2_2, self.pool2),
            (self.conv3_1, None),
            (self.conv3_2, self.pool3),
            (self.conv4_1, None),
            (self.conv4_2, self.pool4),
        )

        spikes = []
        aux_features = []
        signal = x
        for conv, pool in conv_stack:
            spike, aux = conv(signal, detach_grad, **aux_request)
            spikes.append(spike)
            if pool is None:
                signal = spike
            else:
                # Pool the spikes first, then (pre-LIF mode only) the pre-LIF
                # values, so the pooling layer sees the calls in a fixed order.
                signal = pool(spike, detach_grad)
                if use_prelif_for_loss and aux is not None:
                    aux = pool(aux, detach_grad)
            aux_features.append(aux)

        # View layer (reshape to flatten)
        signal = self.view(signal, detach_grad)
        if use_prelif_for_loss and aux_features[-1] is not None:
            # Flatten conv4_2's pooled pre-LIF values to [T, N, features]
            pooled = aux_features[-1]
            steps, batch = pooled.shape[:2]
            aux_features[-1] = pooled.view(steps, batch, -1)

        fc1_spike, fc1_aux = self.fc1(signal, detach_grad, **aux_request)
        output = self.fc2(fc1_spike, detach_grad, act=False)

        features = [x, *aux_features, fc1_aux, output]

        if return_spikes_for_stats:
            spike_features = [x, *spikes, fc1_spike, output]
            return features, spike_features

        return features

    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Propagate ``target`` backwards through every layer's ``reverse`` path.

        Execution order: fc2 -> fc1 -> view -> pool4 -> conv4_2 -> conv4_1 ->
        pool3 -> conv3_2 -> conv3_1 -> pool2 -> conv2_2 -> conv2_1 -> pool1 ->
        conv1_2 -> conv1_1 (the last call uses ``act=False``).

        Args:
            target: With label encoding enabled, a 3-D tensor whose leading
                dimension equals ``self.time_steps``. Otherwise a 1-D tensor is
                one-hot encoded with ``self.num_classes`` classes and a 2-D
                tensor is repeated ``self.time_steps`` times along a new
                leading dimension; tensors of any other rank pass through
                unchanged.
            detach_grad: Passed unchanged to every layer's ``reverse``.
            return_spikes_for_stats: If true, additionally return the list of
                per-layer reverse outputs (used for sparsity statistics).
            use_prelif_for_loss: If true, layers are queried with
                ``return_prelif=True``; otherwise with ``return_readout=True``.

        Returns:
            ``features`` (11 entries), or ``(features, spike_features)`` when
            ``return_spikes_for_stats`` is set.

            * ``features[0]`` is the output of ``conv1_1.reverse`` and
              ``features[10]`` is the prepared target.
            * Readout mode: ``features[1:10]`` are the readouts returned by
              conv1_2 ... conv4_2, fc1, fc2 (in that order).
            * Pre-LIF mode: ``features[1:8]`` are the pre-LIF values returned
              by conv1_2 ... conv4_2, while indices 8 and 9 hold the *spike*
              outputs of ``fc1.reverse`` and ``fc2.reverse``.
            * ``spike_features`` is ``[conv1_1 output, conv1_2 spike, ...,
              conv4_2 spike, fc1 spike, fc2 spike, target]`` in both modes.

        Raises:
            ValueError: In label-encoding mode when ``target`` is not 3-D with
                ``self.time_steps`` as its leading dimension.
        """
        target = self._prepare_target(target)

        if use_prelif_for_loss:
            aux_request = {"return_prelif": True}
        else:
            aux_request = {"return_readout": True}

        # (non-learnable steps to undo first, learnable layer), in execution order.
        chain = (
            ((), self.fc2),
            ((), self.fc1),
            ((self.view, self.pool4), self.conv4_2),
            ((), self.conv4_1),
            ((self.pool3,), self.conv3_2),
            ((), self.conv3_1),
            ((self.pool2,), self.conv2_2),
            ((), self.conv2_1),
            ((self.pool1,), self.conv1_2),
        )

        spikes = []
        aux_values = []
        signal = target
        for undo_steps, layer in chain:
            for step in undo_steps:
                signal = step.reverse(signal, detach_grad)
            signal, aux = layer.reverse(signal, detach_grad, **aux_request)
            spikes.append(signal)
            aux_values.append(aux)

        # Final step: conv1_1 produces the input reconstruction without activation.
        input_reconstructed = self.conv1_1.reverse(signal, detach_grad, act=False)

        # Re-order from execution order (fc2 first) to layer order (conv1_2 first).
        spikes.reverse()
        aux_values.reverse()

        if use_prelif_for_loss:
            # The two FC positions use the reverse spikes rather than the
            # pre-LIF values: fc1's spike sits at index 8 and fc2's at index 9.
            features = [input_reconstructed, *aux_values[:-2], *spikes[-2:], target]
        else:
            features = [input_reconstructed, *aux_values, target]

        if return_spikes_for_stats:
            # Every layer's reverse output appears exactly once (including
            # fc2's spikes), identically in both modes.
            spike_features = [input_reconstructed, *spikes, target]
            return features, spike_features

        return features


class SpikingCNNModel_VGG6(nn.Module):
    """Spiking VGG-6 with readout: six stride-1 conv layers and two FC layers.

    Layout: three blocks of ``conv -> conv -> pool`` (64, 128, 256 channels),
    a view layer that flattens ``(256, 4, 4)`` to 4096 features, then
    ``fc1`` (4096 -> 512) and ``fc2`` (512 -> output_dim).
    NOTE(review): the (256, 4, 4) view shape presumably assumes 32x32 inputs;
    confirm against the data pipeline.

    Both ``forward`` and ``reverse`` return a list of nine entries: the
    (reconstructed) input, one feature per intermediate learnable layer, and
    the output / target.
    """

    # Class count per supported dataset name.
    _NUM_CLASSES = {
        "MNIST": 10,
        "FashionMNIST": 10,
        "CIFAR10": 10,
        "SVHN": 10,
        "STL10_cls": 10,
        "MNIST_CNN": 10,
        "FashionMNIST_CNN": 10,
        "CIFAR100": 100,
    }

    def __init__(self, args):
        """Build the layers and collect their forward/backward parameters.

        Args:
            args: configuration object. Reads ``time_steps``, ``num_chn`` and
                ``dataset``, plus the optional ``use_label_encoding``
                (default ``False``) and ``encoding_dim`` (default ``128``).
                It is also passed through to every layer constructor.

        Raises:
            ValueError: if ``args.dataset`` has no known class count and
                label encoding is disabled (the output width would be
                undefined).
        """
        super().__init__()

        # Store time steps
        self.time_steps = args.time_steps

        # Resolve the output width first so a bad configuration fails with a
        # clear message before any layer is allocated.
        self.use_label_encoding = getattr(args, 'use_label_encoding', False)
        self.num_classes = self._NUM_CLASSES.get(args.dataset)
        if self.use_label_encoding:
            output_dim = getattr(args, 'encoding_dim', 128)
        elif self.num_classes is None:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r} for SpikingCNNModel_VGG6; "
                f"expected one of {sorted(self._NUM_CLASSES)} "
                f"(or enable use_label_encoding)"
            )
        else:
            output_dim = self.num_classes

        # Block 1: 2x Conv3x3(64) + pool
        self.conv1_1 = SpikingC2ConvStride1(args, args.num_chn, 64, 3)
        self.conv1_2 = SpikingC2ConvStride1(args, 64, 64, 3)
        self.pool1 = SpikingC2Pool(args, 2, 2, 0)

        # Block 2: 2x Conv3x3(128) + pool
        self.conv2_1 = SpikingC2ConvStride1(args, 64, 128, 3)
        self.conv2_2 = SpikingC2ConvStride1(args, 128, 128, 3)
        self.pool2 = SpikingC2Pool(args, 2, 2, 0)

        # Block 3: 2x Conv3x3(256) + pool
        self.conv3_1 = SpikingC2ConvStride1(args, 128, 256, 3)
        self.conv3_2 = SpikingC2ConvStride1(args, 256, 256, 3)
        self.pool3 = SpikingC2Pool(args, 2, 2, 0)

        # View and fully connected layers
        self.view = SpikingC2View((256, 4, 4), 256*4*4)
        self.fc1 = SpikingC2Linear(args, 256*4*4, 512)
        self.fc2 = SpikingC2Linear(args, 512, output_dim)

        # Collect parameters of the 8 learnable layers (6 conv + 2 FC)
        self.forward_params = []
        self.backward_params = []
        learnable_layers = [
            self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2,
            self.conv3_1, self.conv3_2, self.fc1, self.fc2,
        ]
        for layer in learnable_layers:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    @staticmethod
    def _aux_kwargs(use_prelif_for_loss):
        """Return the layer keyword that selects which auxiliary feature to emit."""
        if use_prelif_for_loss:
            return {"return_prelif": True}
        return {"return_readout": True}

    @staticmethod
    def _pool_optional(pool, feature, detach_grad):
        """Pool ``feature`` with ``pool``; pass ``None`` through untouched."""
        if feature is None:
            return None
        return pool(feature, detach_grad)

    def _prepare_target(self, target):
        """Bring ``target`` into the time-stepped form ``reverse`` consumes.

        With label encoding the target must already be 3-D with leading size
        ``time_steps``. Otherwise 1-D class indices are one-hot encoded and
        2-D targets are repeated ``time_steps`` times along a new leading
        axis; any other rank is returned unchanged.

        Raises:
            ValueError: in label encoding mode when the shape check fails.
        """
        if self.use_label_encoding:
            if target.dim() == 3 and target.shape[0] == self.time_steps:
                return target
            raise ValueError(f"In label encoding mode, target should have shape [T, num_classes, L], got {target.shape}")

        if target.dim() == 1:
            # one_hot allocates on the input's device, so no explicit move is needed
            target = F.one_hot(target, num_classes=self.num_classes).float()
        if target.dim() == 2:
            target = target.unsqueeze(0).repeat(self.time_steps, 1, 1)
        return target

    def forward(self, x, detach_grad=False, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Run the encoder and collect one feature per learnable layer.

        Args:
            x: input; a 4-D tensor is repeated ``time_steps`` times along a
                new leading axis, anything else is used as given.
            detach_grad: forwarded unchanged to every layer call.
            return_spikes_for_stats: also return the per-layer spike outputs.
            use_prelif_for_loss: collect pre-LIF features (layers called with
                ``return_prelif=True``) instead of readout features
                (``return_readout=True``).

        Returns:
            ``features`` - a list
            ``[x, f1, f2, f3, f4, f5, f6, f7, out]`` where ``fK`` is the
            auxiliary feature of learnable layer K and ``out`` is the
            ``fc2`` result with ``act=False``. In pre-LIF mode ``f2``, ``f4``
            and ``f6`` are passed through the pool that follows their layer,
            and ``f6`` is additionally flattened to ``[T, N, -1]``, so they
            line up with the features produced by ``reverse``.
            When ``return_spikes_for_stats`` is true the result is
            ``(features, spikes)`` with
            ``spikes = [x, s1, ..., s7, out]`` (un-pooled spike outputs).
        """
        # Convert input to time-stepped format if needed
        if x.dim() == 4:  # [N, C, H, W] -> [T, N, C, H, W]
            x = x.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)

        aux = self._aux_kwargs(use_prelif_for_loss)

        # Block 1
        s1, f1 = self.conv1_1(x, detach_grad, **aux)
        s2, f2 = self.conv1_2(s1, detach_grad, **aux)
        h = self.pool1(s2, detach_grad)
        if use_prelif_for_loss:
            # Pool the pre-LIF feature too so it matches the reverse-side spatial size
            f2 = self._pool_optional(self.pool1, f2, detach_grad)

        # Block 2
        s3, f3 = self.conv2_1(h, detach_grad, **aux)
        s4, f4 = self.conv2_2(s3, detach_grad, **aux)
        h = self.pool2(s4, detach_grad)
        if use_prelif_for_loss:
            f4 = self._pool_optional(self.pool2, f4, detach_grad)

        # Block 3
        s5, f5 = self.conv3_1(h, detach_grad, **aux)
        s6, f6 = self.conv3_2(s5, detach_grad, **aux)
        h = self.pool3(s6, detach_grad)
        if use_prelif_for_loss:
            f6 = self._pool_optional(self.pool3, f6, detach_grad)

        # Flatten for the FC layers
        h = self.view(h, detach_grad)
        if use_prelif_for_loss and f6 is not None:
            # The view layer is not learnable, so conv3_2's pre-LIF feature is
            # flattened by hand to sit in the same slot as the reverse feature.
            t, n = f6.shape[:2]
            f6 = f6.view(t, n, -1)

        # FC layers; fc2 is the output layer and runs without activation
        s7, f7 = self.fc1(h, detach_grad, **aux)
        out = self.fc2(s7, detach_grad, act=False)

        features = [x, f1, f2, f3, f4, f5, f6, f7, out]
        if return_spikes_for_stats:
            spike_features = [x, s1, s2, s3, s4, s5, s6, s7, out]
            return features, spike_features
        return features

    def reverse(self, target, detach_grad=True, return_spikes_for_stats=False, use_prelif_for_loss=False):
        """Propagate ``target`` back through the layers' ``reverse`` paths.

        Args:
            target: class indices (1-D), per-sample vectors (2-D) or an
                already time-stepped tensor; see ``_prepare_target``.
            detach_grad: forwarded unchanged to every layer call.
            return_spikes_for_stats: also return the per-layer reverse spikes.
            use_prelif_for_loss: collect pre-LIF features instead of readout
                features.

        Returns:
            ``features`` - a list of nine entries
            ``[recon, conv1_2, conv2_1, conv2_2, conv3_1, conv3_2, fc1, fc2, target]``
            where ``recon`` is ``conv1_1.reverse(..., act=False)`` and each
            named slot holds the auxiliary feature of that layer's reverse
            call. In pre-LIF mode the ``fc1`` and ``fc2`` slots hold the
            reverse *spikes* of those layers rather than their pre-LIF
            values (unchanged from the previous implementation).
            When ``return_spikes_for_stats`` is true the result is
            ``(features, spikes)``; ``spikes`` uses the same slot layout with
            every layer's reverse spike, identically in both modes.

        Raises:
            ValueError: in label encoding mode when ``target`` is not 3-D
                with leading size ``time_steps``.
        """
        target = self._prepare_target(target)
        aux = self._aux_kwargs(use_prelif_for_loss)

        # FC layers, then undo the flattening
        fc2_spike, fc2_aux = self.fc2.reverse(target, detach_grad, **aux)
        fc1_spike, fc1_aux = self.fc1.reverse(fc2_spike, detach_grad, **aux)
        h = self.view.reverse(fc1_spike, detach_grad)

        # Block 3
        h = self.pool3.reverse(h, detach_grad)
        conv3_2_spike, conv3_2_aux = self.conv3_2.reverse(h, detach_grad, **aux)
        conv3_1_spike, conv3_1_aux = self.conv3_1.reverse(conv3_2_spike, detach_grad, **aux)

        # Block 2
        h = self.pool2.reverse(conv3_1_spike, detach_grad)
        conv2_2_spike, conv2_2_aux = self.conv2_2.reverse(h, detach_grad, **aux)
        conv2_1_spike, conv2_1_aux = self.conv2_1.reverse(conv2_2_spike, detach_grad, **aux)

        # Block 1
        h = self.pool1.reverse(conv2_1_spike, detach_grad)
        conv1_2_spike, conv1_2_aux = self.conv1_2.reverse(h, detach_grad, **aux)

        # Final step produces the input reconstruction (no activation)
        recon = self.conv1_1.reverse(conv1_2_spike, detach_grad, act=False)

        if use_prelif_for_loss:
            # Slots 6/7 carry the FC reverse spikes in pre-LIF mode; the FC
            # pre-LIF values are intentionally not used here.
            fc_slots = [fc1_spike, fc2_spike]
        else:
            fc_slots = [fc1_aux, fc2_aux]

        features = [recon, conv1_2_aux, conv2_1_aux, conv2_2_aux,
                    conv3_1_aux, conv3_2_aux, *fc_slots, target]

        if return_spikes_for_stats:
            # Same layout in both modes. The pre-LIF branch used to list the
            # reconstruction twice and omit fc2's spike, shifting every entry.
            spike_features = [recon, conv1_2_spike, conv2_1_spike, conv2_2_spike,
                              conv3_1_spike, conv3_2_spike, fc1_spike, fc2_spike, target]
            return features, spike_features
        return features